# Experiments with contrastive learning

This codebase supports zeef@ and nishanthd@'s experiments with contrastive learning and images defined by sets of latent values. Our goal is to explore whether we can recover the latent values from representations generated by self-supervised models such as SimCLR, inspired by [arXiv:2103.06875](https://arxiv.org/abs/2103.06875) and [arXiv:2102.08850](https://arxiv.org/abs/2102.08850).

The code is based on TensorFlow 2.x and uses both `tf.data.Dataset` objects and pandas DataFrames to manage the datasets.

## Datasets

We focus on [dsprites](https://www.tensorflow.org/datasets/catalog/dsprites) and [3dident](https://github.com/brendel-group/cl-ica), since both consist of images defined by a set of latent values. Although dsprites is available through tfds, the TensorFlow dataset format makes it difficult to look up specific examples or to search for similar examples, which is needed for generating positive pairs for contrastive training. 3dident is not currently available through tfds at all.

`datasets.py` contains functions to convert dsprites to a pandas DataFrame format, load either dsprites or 3dident in a standardized format ready for experiments, and generate and load contrastive training sets from either dataset. There is currently no function to convert 3dident to this format, but it can be downloaded from [Zenodo](https://zenodo.org/record/4502485#.YecGYU3MIjg), extracted to the folder of your choosing, and then loaded using `datasets.get_standard_dataset`.
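To illustrate why a DataFrame makes it easier to look up specific examples, here is a minimal sketch on a toy table. The column names (`shape`, `scale`, `example_id`) and the values are assumptions made for this example only; the actual schema is whatever `datasets.py` produces.

```python
import pandas as pd

# Toy table standing in for a converted dataset.
# Column names here are hypothetical, not the schema used by datasets.py.
df = pd.DataFrame({
    "example_id": ["a", "b", "c", "d", "e"],
    "shape": [0, 0, 1, 1, 2],
    "scale": [0.5, 1.0, 0.5, 1.0, 0.5],
})

# Look up the example(s) with one specific combination of latent values.
match = df[(df["shape"] == 1) & (df["scale"] == 1.0)]
matched_ids = match["example_id"].tolist()

# Count how many examples share a given latent value.
num_shape_zero = int((df["shape"] == 0).sum())
```

With a `tf.data.Dataset`, the same lookup would generally require iterating over or filtering the whole dataset.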

`data_utils.py` contains helper functions for working with both datasets, including dataset-specific functions for preprocessing and searching for similar examples.
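As a rough illustration of what searching for a similar example can look like, the sketch below picks the nearest neighbour of an anchor in latent space. The function name `find_positive_index`, the Euclidean distance, and the toy latent table are all assumptions for this example; the dataset-specific search functions in `data_utils.py` may work differently.

```python
import numpy as np

def find_positive_index(latents, anchor_idx, scales=None):
    """Return the index of the example whose latents are closest to the anchor's.

    latents: array of shape (num_examples, num_latents).
    scales: optional per-latent divisors, so that latents with large
        numeric ranges do not dominate the distance.
    """
    latents = np.asarray(latents, dtype=float)
    if scales is None:
        scales = np.ones(latents.shape[1])
    diffs = (latents - latents[anchor_idx]) / scales
    dists = np.linalg.norm(diffs, axis=1)
    dists[anchor_idx] = np.inf  # never pair an example with itself
    return int(np.argmin(dists))

# Toy latent table; columns are (scale, x_position, y_position).
toy_latents = np.array([
    [0.5, 0.0, 0.0],
    [0.5, 0.1, 0.0],
    [1.0, 0.9, 0.9],
    [1.0, 0.8, 0.9],
])
positive_for_0 = find_positive_index(toy_latents, anchor_idx=0)
positive_for_2 = find_positive_index(toy_latents, anchor_idx=2)
```

Each anchor and its returned neighbour could then be treated as a positive pair for contrastive training.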

Other datasets could easily be added to this framework by mimicking the preprocessing functions in `data_utils.py` and adding the appropriate logic to `datasets.get_standard_dataset`.

## Training

`train_linear_layer.py` implements a custom training loop to train a linear layer on top of a pre-trained SimCLR model. It can be run either on a single GPU (the limiting factor being whether the pre-trained SimCLR model fits in the GPU's memory) or on multiple TPUs via `tf.distribute`.
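The idea of fitting a linear layer on frozen representations can be sketched without TensorFlow. The example below is a NumPy stand-in, not the loop in `train_linear_layer.py`: the random features play the role of frozen encoder outputs, the latents are synthetic, and the mean-squared-error loss with plain gradient descent is an assumption made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for frozen encoder outputs: 256 examples, 16-dim representations.
features = rng.normal(size=(256, 16))
# Synthetic latents that are an exact linear function of the features.
true_weights = rng.normal(size=(16, 3))
latents = features @ true_weights

# Parameters of the linear layer; only these are updated.
weights = np.zeros((16, 3))
bias = np.zeros(3)
learning_rate = 0.1

def mse(w, b):
    """Mean squared error over all examples and latents."""
    return float(np.mean((features @ w + b - latents) ** 2))

initial_loss = mse(weights, bias)
for step in range(500):
    error = features @ weights + bias - latents  # shape (256, 3)
    # Exact gradients of the mean squared error defined in mse().
    grad_w = 2.0 * features.T @ error / error.size
    grad_b = 2.0 * error.sum(axis=0) / error.size
    weights -= learning_rate * grad_w
    bias -= learning_rate * grad_b
final_loss = mse(weights, bias)
```

Because the synthetic latents are linear in the features, the loss drops to nearly zero here; with real representations, the remaining error indicates how much of each latent is linearly recoverable.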

We provide helper functions for measuring training progress on the dsprites dataset in `metrics_utils.py`. These include custom accuracy measurements (e.g. defining a latent value to be "accurately" predicted if it is closer to the correct latent value than adjacent ones) and a class to handle metrics logging of each latent during training.
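One way to read the accuracy definition above: a prediction counts as correct when the allowed latent value nearest to it is the true value. The sketch below implements that reading. The function name `nearest_value_accuracy` and the predictions are made up for this example, and the implementation in `metrics_utils.py` may differ. The six linearly spaced scale values in [0.5, 1] follow the dsprites dataset description.

```python
import numpy as np

def nearest_value_accuracy(predictions, targets, allowed_values):
    """Fraction of predictions whose nearest allowed value is the target."""
    allowed = np.asarray(allowed_values, dtype=float)
    predictions = np.asarray(predictions, dtype=float)
    targets = np.asarray(targets, dtype=float)
    # Index of the closest allowed value for each prediction and each target.
    pred_idx = np.abs(predictions[:, None] - allowed[None, :]).argmin(axis=1)
    target_idx = np.abs(targets[:, None] - allowed[None, :]).argmin(axis=1)
    return float(np.mean(pred_idx == target_idx))

# dsprites scale latent: 6 values linearly spaced in [0.5, 1].
scale_values = np.linspace(0.5, 1.0, 6)
preds = [0.52, 0.66, 0.94, 0.79]   # made-up model outputs
targets = [0.5, 0.6, 1.0, 0.8]     # true latent values
acc = nearest_value_accuracy(preds, targets, scale_values)
```

Here the first and last predictions are closest to their true values, while 0.66 and 0.94 land nearer to the adjacent values 0.7 and 0.9, so `acc` is 0.5.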

Further metrics can easily be added by modifying the `setup_metrics` and `update_metrics` methods of the appropriate `MetricsInterface` class.

To launch an experiment, use `python -m train_linear_layer --pretrained_model_path=...` (note that `-m` takes the module name without the `.py` suffix), filling in the path to where your pretrained SimCLR (or equivalent) model weights are stored.
